
import datetime
import random
import time

import numpy as np
import torch
import mat4py
import faiss
from tqdm import tqdm, trange
from thop import profile
from thop import clever_format
import torch.nn.functional as F

from cbml_benchmark.data.evaluations import RetMetric, RecallCompute, RetMetricMap
from cbml_benchmark.utils.feat_extractor import feat_extractor
from cbml_benchmark.utils.freeze_bn import set_bn_eval
from cbml_benchmark.utils.metric_logger import MetricLogger

import matplotlib.pyplot as plt
import matplotlib as mlp
import os

def update_ema_variables(model, ema_model, alpha=0.999):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    For every pair of parameters (matched by iteration order of
    ``parameters()``) this computes
    ``ema = alpha * ema + (1 - alpha) * param``.

    Args:
        model: Module whose current parameters are the new observation.
        ema_model: Module holding the running average; modified in place.
        alpha: Decay factor of the moving average. Defaults to ``0.999``.
    """
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # Keyword ``alpha`` replaces the deprecated positional
        # ``add_(Number, Tensor)`` overload.
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

def _recall_at_ks(feats, labels, ks=(1, 2, 4, 8)):
    """Return ``[recall@k for k in ks]`` computed by ``RetMetric`` on *feats*."""
    ret_metric = RetMetric(feats=feats, labels=labels)
    return [ret_metric.recall_k(k) for k in ks]


def _evaluate_loader(model, loader, logger):
    """Extract features for *loader* and print recall@{1,2,4,8} for two embeddings.

    The first and second values returned by ``feat_extractor`` are each scored
    and printed, in that order. The recall list of the second one is returned.
    """
    labels = np.array([int(k) for k in loader.dataset.label_list])
    localfeats3, localfeats, _feats_mp, _centers = feat_extractor(
        model, loader, logger=logger)

    print(_recall_at_ks(localfeats3, labels))

    recall_curr = _recall_at_ks(localfeats, labels)
    print(recall_curr)
    return recall_curr


def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        train_val_loader,
        optimizer,
        scheduler,
        criterion,
        trans_criterion,
        criterion_aux,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger
):
    """Run the training loop with periodic validation and checkpointing.

    Args:
        cfg: Config object; reads ``INPUT.CROP_SIZE``, ``VALIDATION.VERBOSE``
            and ``SOLVER.MAX_ITERS``.
        model: Network exposing ``backbone``, ``localhead2``, ``localhead3``
            and ``finalhead``.
        train_loader: Iterable yielding ``(images, targets, indexes)``; its
            length is used as the number of iterations.
        val_loader: Loader whose recall@1 (second embedding returned by
            ``feat_extractor``) selects the best checkpoint.
        train_val_loader: Loader evaluated (recalls printed only) before
            ``val_loader`` at each validation step.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        criterion: Loss applied to both embeddings against ``targets``.
        trans_criterion: Loss between the final embedding and the detached
            intermediate embedding.
        criterion_aux: Unused; kept for interface compatibility.
        checkpointer: Object with ``save(name)``.
        device: Device the images/targets are moved to.
        checkpoint_period: Save ``model_<iteration>`` every this many iterations.
        arguments: Dict whose ``"iteration"`` entry gives the start iteration
            and is updated in place every step.
        logger: Logger used for progress messages.
    """
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    max_iter = len(train_loader)

    # Report FLOPs / parameter count of the backbone on one dummy crop.
    dummy_input = torch.randn(
        1, 3, cfg.INPUT.CROP_SIZE, cfg.INPUT.CROP_SIZE).to(device)
    flops, params = profile(model.backbone, inputs=(dummy_input,))
    flops, params = clever_format([flops, params], "%.3f")
    print(flops, params)

    start_iter = arguments["iteration"]
    best_iteration = -1
    best_recall = 0

    # Base weight of the distillation term; the remainder goes to ``loss3``.
    beta = 0.9

    start_training_time = time.time()
    end = time.time()

    for iteration, (images, targets, _indexes) in enumerate(train_loader, start_iter):
        # NOTE(review): this check runs before ``iteration`` is incremented, so
        # when starting from 0 the ``iteration == max_iter`` arm cannot fire and
        # the weights after the very last step are not validated here.
        if iteration % cfg.VALIDATION.VERBOSE == 0 or iteration == max_iter:
            model.eval()
            logger.info('Validation')

            # Recalls on train_val_loader are only printed; model selection
            # uses the val_loader recalls.
            _evaluate_loader(model, train_val_loader, logger)
            recall_curr = _evaluate_loader(model, val_loader, logger)

            if recall_curr[0] > best_recall:
                best_recall = recall_curr[0]
                best_iteration = iteration
                logger.info(f'Best iteration {iteration}: recall@1: {recall_curr[0]:.3f}')
                checkpointer.save("best_model")
            else:
                logger.info(f'Recall@1 at iteration {iteration:06d}: recall@1: {recall_curr[0]:.3f}')

        model.train()

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        # NOTE(review): the scheduler is stepped before optimizer.step(); PyTorch
        # >= 1.1 recommends the opposite order. Left as-is so the learning-rate
        # schedule is unchanged.
        scheduler.step()

        images = images.to(device)
        targets = torch.stack([target.to(device) for target in targets])

        out3, feats_all, feats_center3, _feats_center2 = model.backbone(images)

        # The output of localhead2 does not enter the loss. The forward call is
        # kept so that any state the head updates during forward is unaffected.
        # NOTE(review): drop this call if the head is stateless.
        model.localhead2(feats_all)

        feats_local3 = model.localhead3(out3)
        feats_local3 = F.avg_pool2d(feats_local3, (8, 8), 8, 0).squeeze(2).squeeze(2)
        feats_local3 = F.normalize(feats_local3, p=2, dim=1)

        feats = model.finalhead(feats_local3)
        feats = F.normalize(feats, p=2, dim=1)

        loss3 = criterion(feats_local3, targets)
        loss_final = criterion(feats, targets)

        # Mean absolute cosine similarity over distinct pairs of rows of
        # ``feats_center3`` (strict upper triangle of the similarity matrix).
        feats_center3 = F.normalize(feats_center3, p=2, dim=1)
        sim_center = torch.abs(torch.mm(feats_center3, feats_center3.permute(1, 0)))
        num_centers = feats_center3.size(0)
        loss_dict3 = torch.triu(sim_center, 1).sum() / (
            num_centers * num_centers - num_centers) * 2.

        # Stop gradient on the second branch.
        loss_distill = trans_criterion(feats, feats_local3.detach())

        # Linearly shift weight from ``loss3`` to ``loss_distill`` as training
        # progresses (lamd goes from ~0 to iteration / MAX_ITERS).
        lamd = iteration / cfg.SOLVER.MAX_ITERS
        weight_loss3 = 1. - beta - (1. - beta) * lamd
        weight_distill = (1. - beta) * lamd + beta
        loss = loss_final + 10.0 * loss_dict3 + weight_loss3 * loss3 + \
            weight_distill * loss_distill

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        # The meter key is ``loss_final`` but the value logged is the total loss.
        meters.update(time=batch_time, data=data_time, loss_final=loss.item())

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.1f} GB",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0,
                )
            )

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:06d}".format(iteration))

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            # Guard against an empty loader (max_iter == 0).
            total_time_str, total_training_time / max(max_iter, 1)
        )
    )

    logger.info(f"Best iteration: {best_iteration :06d} | best recall {best_recall} ")

def do_test(
        model,
        val_loader,
        logger
):
    """Evaluate *model* on *val_loader* and print recall@{1, 2, 4, 8}.

    Args:
        model: Network to evaluate; switched to eval mode.
        val_loader: Loader whose ``dataset.label_list`` provides the labels.
        logger: Logger used for progress messages.
    """
    logger.info("Start testing")
    model.eval()
    logger.info('test')

    labels = np.array([int(k) for k in val_loader.dataset.label_list])

    # feat_extractor returns four values (it is unpacked that way in do_train).
    # Passing the whole return value to RetMetric was inconsistent with that,
    # so score the second embedding, the one do_train uses for model selection.
    _localfeats3, localfeats, _feats_mp, _centers = feat_extractor(
        model, val_loader, logger=logger)

    ret_metric = RetMetric(feats=localfeats, labels=labels)
    recall_curr = [ret_metric.recall_k(k) for k in (1, 2, 4, 8)]

    print(recall_curr)
